# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Component interface: display name, declared inputs, and declared outputs.
name: Batch Prediction

# Inputs. All are String or Integer parameters except input_dataset,
# which is declared with the Dataset type.
inputs:
- name: project_id
  type: String
- name: data_region
  type: String
- name: data_pipeline_root
  type: String
- name: gcs_result_folder
  type: String
- name: instances_format
  type: String
- name: predictions_format
  type: String
- name: model_resource_name
  type: String
- name: endpoint_resource_name
  type: String
- name: machine_type
  type: String
- name: accelerator_type
  type: String
- name: accelerator_count
  type: Integer
- name: max_replica_count
  type: Integer
- name: starting_replica_count
  type: Integer
- name: input_dataset
  type: Dataset

# Outputs. A single Dataset-typed output.
outputs:
- name: prediction_result
  type: Dataset
# Container implementation: runs predict.py in the component-base image and
# hands it every declared input/output as a command-line flag.
implementation:
  container:
    # The {{...}} placeholders are presumably filled in by a templating step
    # before this file is loaded (confirm against the generator). The value is
    # quoted so that a leading "{{" is never read as a YAML flow mapping.
    image: "{{af_registry_location}}-docker.pkg.dev/{{project_id}}/{{af_registry_name}}/component-base:latest"
    command: [python, /pipelines/component/src/predict.py]
    # Argument list, written as flag/value pairs (one pair per line).
    # - --executor_input receives the executorInput placeholder.
    # - --function_to_execute is set to batch_prediction.
    # - Every remaining flag is the matching input/output name with
    #   underscores replaced by hyphens; parameters are passed by value
    #   (inputValue), the Dataset input/output by path (inputPath/outputPath).
    args: [
      --executor_input, {executorInput: null},
      --function_to_execute, batch_prediction,
      --project-id, {inputValue: project_id},
      --data-region, {inputValue: data_region},
      --data-pipeline-root, {inputValue: data_pipeline_root},
      --gcs-result-folder, {inputValue: gcs_result_folder},
      --instances-format, {inputValue: instances_format},
      --predictions-format, {inputValue: predictions_format},
      --model-resource-name, {inputValue: model_resource_name},
      --endpoint-resource-name, {inputValue: endpoint_resource_name},
      --machine-type, {inputValue: machine_type},
      --accelerator-type, {inputValue: accelerator_type},
      --accelerator-count, {inputValue: accelerator_count},
      --starting-replica-count, {inputValue: starting_replica_count},
      --max-replica-count, {inputValue: max_replica_count},
      --input-dataset, {inputPath: input_dataset},
      --prediction-result, {outputPath: prediction_result}
    ]
